{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nH85BOCo7YYk"
   },
   "source": [
    "##### Copyright 2024 Google LLC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "9tQNAByc7U9g"
   },
   "outputs": [],
   "source": [
    "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F7r2q0wS7bxf"
   },
   "source": [
    "# Play with AI - Guess the word\n",
    "\n",
    "This cookbook illustrates how you can use the instruction-tuned version of Gemma as a chatbot to play the \"Guess the word\" game: the model picks a word from a theme you choose and describes it without saying it, and you try to guess the word.\n",
    "\n",
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Guess_the_word.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZHrL4tqs7mYK"
   },
   "source": [
    "## Setup\n",
    "\n",
    "### Select the Colab runtime\n",
    "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
    "\n",
    "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
    "2. Select **Change runtime type**.\n",
    "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
    "\n",
    "\n",
    "### Gemma setup on Kaggle\n",
    "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
    "\n",
    "* Get access to Gemma on kaggle.com.\n",
    "* Select a Colab runtime with sufficient resources to run the Gemma 2 2B model.\n",
    "* Generate and configure a Kaggle username and API key.\n",
    "\n",
    "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pQEE8RoO75F-"
   },
   "source": [
    "### Set environment variables\n",
    "\n",
    "Set `KERAS_BACKEND` to choose the Keras backend, and set `KAGGLE_USERNAME` and `KAGGLE_KEY` so that the Gemma weights can be downloaded from Kaggle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "XsY2Ut7a76Wa"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from google.colab import userdata\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"tensorflow\" or \"torch\".\n",
    "\n",
    "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
    "# vars as appropriate for your system.\n",
    "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
    "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ea_56Zpa78Gu"
   },
   "source": [
    "### Install dependencies\n",
    "\n",
    "Install Keras and KerasNLP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AxPjbcnC79ck"
   },
   "outputs": [],
   "source": [
    "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
    "!pip install -q -U keras-nlp\n",
    "!pip install -q -U keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a_QCPQLf8OU0"
   },
   "source": [
    "### Create a chat helper to manage the conversation state\n",
    "\n",
    "Load the instruction-tuned Gemma 2 2B model and define `ChatState`, a small helper that stores the conversation history and wraps each turn in Gemma's `<start_of_turn>`/`<end_of_turn>` markers before calling `generate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2BmB5Zua8Vs0"
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "import keras\n",
    "import keras_nlp\n",
    "\n",
    "# Run at half precision to fit in memory\n",
    "keras.config.set_floatx(\"bfloat16\")\n",
    "\n",
    "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_instruct_2b_en\")\n",
    "gemma_lm.compile(sampler=\"top_k\")\n",
    "\n",
    "\n",
    "class ChatState:\n",
    "    \"\"\"\n",
    "    Manages the conversation history for a turn-based chatbot.\n",
    "    Follows the turn-based conversation guidelines for the Gemma family of models\n",
    "    documented at https://ai.google.dev/gemma/docs/formatting\n",
    "    \"\"\"\n",
    "\n",
    "    __START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
    "    __START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
    "    __END_TURN__ = \"<end_of_turn>\\n\"\n",
    "\n",
    "    def __init__(self, model, system=\"\"):\n",
    "        \"\"\"\n",
    "        Initializes the chat state.\n",
    "\n",
    "        Args:\n",
    "            model: The language model to use for generating responses.\n",
    "            system: (Optional) System instructions or bot description.\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        self.system = system\n",
    "        self.history = []\n",
    "\n",
    "    def add_to_history_as_user(self, message):\n",
    "        \"\"\"\n",
    "        Adds a user message to the history with start/end turn markers.\n",
    "        \"\"\"\n",
    "        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)\n",
    "\n",
    "    def add_to_history_as_model(self, message):\n",
    "        \"\"\"\n",
    "        Adds a model response to the history with a start-of-turn marker.\n",
    "        No end marker is appended here: the generated text normally already\n",
    "        ends with `<end_of_turn>`.\n",
    "        \"\"\"\n",
    "        self.history.append(self.__START_TURN_MODEL__ + message)\n",
    "\n",
    "    def get_history(self):\n",
    "        \"\"\"\n",
    "        Returns the entire chat history as a single string.\n",
    "        \"\"\"\n",
    "        return \"\".join(self.history)\n",
    "\n",
    "    def get_full_prompt(self):\n",
    "        \"\"\"\n",
    "        Builds the prompt for the language model, including history and system description.\n",
    "        \"\"\"\n",
    "        prompt = self.get_history() + self.__START_TURN_MODEL__\n",
    "        if self.system:\n",
    "            prompt = self.system + \"\\n\" + prompt\n",
    "        return prompt\n",
    "\n",
    "    def send_message(self, message):\n",
    "        \"\"\"\n",
    "        Handles sending a user message and getting a model response.\n",
    "\n",
    "        Args:\n",
    "            message: The user's message.\n",
    "\n",
    "        Returns:\n",
    "            The model's response.\n",
    "        \"\"\"\n",
    "        self.add_to_history_as_user(message)\n",
    "        prompt = self.get_full_prompt()\n",
    "        response = self.model.generate(prompt, max_length=4096)\n",
    "        result = response.replace(prompt, \"\")  # Extract only the new response\n",
    "        self.add_to_history_as_model(result)\n",
    "        return result\n",
    "\n",
    "    def show_history(self):\n",
    "        \"\"\"\n",
    "        Prints each turn stored in the chat history.\n",
    "        \"\"\"\n",
    "        for h in self.history:\n",
    "            print(h)\n",
    "\n",
    "\n",
    "chat = ChatState(gemma_lm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_1jyCoRd8EwX"
   },
   "source": [
    "## Play the game\n",
    "\n",
    "Run the cell below, enter a theme (for example, `animal`), and type your guesses. After each wrong guess, the model describes its word again. Type `quit` to give up and reveal the answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "zoWDt87V83rZ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Choose your theme: animal\n",
      "Guess what I'm thinking.\n",
      "Type \"quit\" if you want to quit.\n",
      "A playful, furry swimmer. Known for its playful antics and clever use of tools. A member of the weasel-like family with a distinctive, waterproof coat. \n",
      "<end_of_turn>\n",
      "\n",
      "> platypus\n",
      "A creature that spends its days by a river or lake, often seen splashing and diving with its distinctive, thick fur. It's known for its playful demeanor and ability to hold a surprisingly large amount of water in its paws. \n",
      "<end_of_turn>\n",
      "\n",
      "> beaver\n",
      "A small, aquatic mammal with a playful, curious nature, often spotted near water, known for its distinctive, waterproof fur and love for swimming. \n",
      "<end_of_turn>\n",
      "\n",
      "> otter\n",
      "Correct!\n"
     ]
    }
   ],
   "source": [
    "theme = input(\"Choose your theme: \")\n",
    "setup_message = f\"Generate a random single word from {theme}.\"\n",
    "\n",
    "# Ask the model for a secret word, then reset the history so the setup\n",
    "# exchange is not part of the prompts that follow.\n",
    "chat.history.clear()\n",
    "answer = chat.send_message(setup_message).split()[0]\n",
    "answer = re.sub(r\"\\W+\", \"\", answer)  # removes everything except letters, digits and '_'\n",
    "chat.history.clear()\n",
    "cmd_exit = \"quit\"\n",
    "question = f'Describe the word \"{answer}\" without saying it.'\n",
    "remove_answer = re.compile(re.escape(answer), re.IGNORECASE)\n",
    "\n",
    "resp = \"\"\n",
    "while resp.lower() != answer.lower() and resp != cmd_exit:\n",
    "    text = chat.send_message(question)\n",
    "    if resp == \"\":\n",
    "        print(f'Guess what I\\'m thinking.\\nType \"{cmd_exit}\" if you want to quit.')\n",
    "    text = remove_answer.sub(\"XXXX\", text)  # Mask the answer if the model says it\n",
    "    print(text)\n",
    "    resp = input(\"\\n> \")\n",
    "\n",
    "if resp == cmd_exit:\n",
    "    print(f\"The answer was {answer}.\\n\")\n",
    "else:\n",
    "    print(\"Correct!\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "[Gemma_2]Guess_the_word.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
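The prompt bookkeeping in the notebook's `ChatState` class can be exercised without downloading Gemma. The sketch below is a simplified, standalone re-implementation under stated assumptions: `build_prompt`, `send_message`, and `EchoModel` are hypothetical names, not part of the notebook or of KerasNLP, and `EchoModel` only imitates what the notebook's `response.replace(prompt, "")` step assumes, namely that `generate` returns the prompt followed by the new text.

```python
# Hypothetical standalone sketch of ChatState's prompt handling; runs without Gemma.
# The turn markers are the same strings the notebook's ChatState class uses.
START_TURN_USER = "<start_of_turn>user\n"
START_TURN_MODEL = "<start_of_turn>model\n"
END_TURN = "<end_of_turn>\n"


class EchoModel:
    """Hypothetical fake model: returns the prompt followed by a fixed reply."""

    def generate(self, prompt, max_length=4096):
        return prompt + "A playful swimmer.<end_of_turn>"


def build_prompt(history, system=""):
    """Joins the history and opens a model turn, like ChatState.get_full_prompt."""
    prompt = "".join(history) + START_TURN_MODEL
    if system:
        prompt = system + "\n" + prompt
    return prompt


def send_message(history, message, model, system=""):
    """Mirrors ChatState.send_message, using a plain list as the history."""
    history.append(START_TURN_USER + message + END_TURN)
    prompt = build_prompt(history, system)
    response = model.generate(prompt, max_length=4096)
    result = response.replace(prompt, "")  # keep only the newly generated text
    history.append(START_TURN_MODEL + result)
    return result


history = []
reply = send_message(history, 'Describe the word "otter" without saying it.', EchoModel())
print(repr(reply))
print(build_prompt(history))
```

Note that `str.replace` removes every occurrence of the prompt, not only the leading one; slicing with `response[len(prompt):]` is an alternative if the output is guaranteed to begin with the exact prompt text.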
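The game cell in the notebook uses two regular expressions for different jobs: `\W+` strips punctuation from the first word of the model's reply, and an escaped, case-insensitive pattern masks the answer if the model says it in a description. A minimal sketch of both steps, runnable without Gemma; the helper names `clean_answer` and `mask_answer` and the sample reply strings are hypothetical, chosen only for illustration.

```python
import re


def clean_answer(raw_reply):
    """Takes the first whitespace-separated token and strips non-word characters.

    The pattern removes anything that is not a letter, digit or underscore.
    """
    first_word = raw_reply.split()[0]
    return re.sub(r"\W+", "", first_word)


def mask_answer(text, answer, mask="XXXX"):
    """Replaces every case-insensitive occurrence of the answer with a mask.

    re.escape makes characters such as '.' match literally.
    """
    pattern = re.compile(re.escape(answer), re.IGNORECASE)
    return pattern.sub(mask, text)


answer = clean_answer("**Otter**\n<end_of_turn>")
print(answer)
print(mask_answer("Otters are playful. An otter uses tools.", answer))
```

Because the match is a plain substring search, "Otters" becomes "XXXXs"; the notebook's masking behaves the same way.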